{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Object Detection with SSD\n",
    "### Here we demonstrate detection on example images using SSD with PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent directory on sys.path so the project packages imported in this notebook (`ssd`, `data`) resolve\n",
    "module_path = os.path.abspath(os.pardir)\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import cv2\n",
    "\n",
    "# With a GPU available, tensors built from the default type (e.g. torch.Tensor(...)) are created on CUDA\n",
    "if torch.cuda.is_available():\n",
    "    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "\n",
    "# Must come after the sys.path update above\n",
    "from ssd import build_ssd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build SSD300 in Test Phase\n",
    "1. Build the architecture, specifying the size of the input image (300),\n",
    "    and the number of object classes to score (21 for VOC: 20 object classes plus background)\n",
    "2. Next, we load weights pretrained on the VOC0712 trainval dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "net = build_ssd('test', 300, 21)    # initialize SSD: 300x300 input, 21 class scores\n",
    "net.eval()                          # inference mode: any train/eval-dependent layers use their eval behavior\n",
    "net.load_weights('../weights/ssd300_VOC_28000.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Image\n",
    "### Here we just load a sample image from the VOC2007 dataset (validation split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If the VOC dataset is not downloaded, load the example image instead (uncomment):\n",
    "# image = cv2.imread('./data/example.jpg', cv2.IMREAD_COLOR)\n",
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform\n",
    "\n",
    "# Image sets are (year, split) pairs: year '2007' or '2012', split 'test', 'val' or 'train'\n",
    "testset = VOCDetection(VOC_ROOT, [('2007', 'val')], None, VOCAnnotationTransform())\n",
    "img_id = 60\n",
    "image = testset.pull_image(img_id)\n",
    "\n",
    "# The loaded image is BGR (OpenCV order); matplotlib expects RGB\n",
    "rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "# Show the sampled input image before any transform\n",
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(rgb_image)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-process the input.  \n",
    "#### Using the torchvision package, we can create a Compose of multiple built-in transform ops to apply\n",
    "For SSD at test time, we use a custom BaseTransform callable; the cell below performs the same steps inline:\n",
    "resize the image to 300x300, subtract the dataset's per-channel mean values (given in BGR order, matching the OpenCV image),\n",
    "and swap the color channels from BGR to RGB for input to SSD300."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize to the 300x300 SSD input and convert to float32 before subtracting the means\n",
    "x = cv2.resize(image, (300, 300)).astype(np.float32)\n",
    "# Per-channel means in B, G, R order, matching the BGR image; the in-place subtract keeps x float32\n",
    "x -= (104.0, 117.0, 123.0)\n",
    "# BGR -> RGB; .copy() because torch.from_numpy rejects the negative strides of a reversed view\n",
    "x = x[:, :, ::-1].copy()\n",
    "# Preview the network input. imshow clips float RGB data to [0, 1], so min-max rescale for display only\n",
    "plt.imshow((x - x.min()) / max(float(x.max() - x.min()), 1e-6))\n",
    "# HWC -> CHW (channels-first) tensor\n",
    "x = torch.from_numpy(x).permute(2, 0, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SSD Forward Pass\n",
    "### Add a batch dimension, wrap the image in a `Variable` (a no-op since PyTorch 0.4 merged Variables into Tensors), and run it through SSD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "xx = Variable(x.unsqueeze(0))     # add a batch dimension; Variable is a no-op wrapper since PyTorch 0.4\n",
    "if torch.cuda.is_available():\n",
    "    xx = xx.cuda()\n",
    "# Inference only: skip autograd graph construction to save memory and time\n",
    "with torch.no_grad():\n",
    "    y = net(xx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parse the Detections and View Results\n",
    "Filter out detections whose confidence score is below a threshold;\n",
    "here we choose 0.6 (60%)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import VOC_CLASSES as labels\n",
    "conf_thresh = 0.6  # minimum confidence score for a detection to be drawn\n",
    "\n",
    "detections = y.data\n",
    "num_classes = detections.size(1)\n",
    "max_per_class = detections.size(2)\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "colors = plt.cm.hsv(np.linspace(0, 1, num_classes)).tolist()\n",
    "plt.imshow(rgb_image)  # plot the image for matplotlib\n",
    "currentAxis = plt.gca()\n",
    "\n",
    "# scale each detection back up to the image: box corners (x1, y1, x2, y2) times (W, H, W, H)\n",
    "scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)\n",
    "# class i is named labels[i-1], so class 0 has no label; start at 1\n",
    "for i in range(1, num_classes):\n",
    "    # assumes each class's detections are sorted by descending score: stop at the first one\n",
    "    # below the threshold, or after the last slot so a full row cannot index out of range\n",
    "    j = 0\n",
    "    while j < max_per_class and detections[0,i,j,0] >= conf_thresh:\n",
    "        score = detections[0,i,j,0]\n",
    "        label_name = labels[i-1]\n",
    "        display_txt = '%s: %.2f'%(label_name, score)\n",
    "        pt = (detections[0,i,j,1:]*scale).cpu().numpy()\n",
    "        coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1\n",
    "        color = colors[i]\n",
    "        currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))\n",
    "        currentAxis.text(pt[0], pt[1], display_txt, bbox={'facecolor':color, 'alpha':0.5})\n",
    "        j+=1"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
